//
// Copyright © 2018-2021 Arm Limited.
// SPDX-License-Identifier: Apache-2.0
//

#include "SpaceToDepthPass.hpp"
#include "SramAllocator.hpp"

#include "Utils.hpp"

namespace ethosn
{
namespace support_library
{

using namespace utils;

// Chooses how to split the IFM into stripes along its height for the space to depth operation and allocates
// the SRAM needed for the work buffers of one such stripe.
//
// Stripe heights are tried from the full IFM height downwards. Only heights that evenly divide the IFM height
// and are a multiple of the block size are considered. The first height for which the SRAM allocation
// succeeds is used.
//
// nodeId              - Id passed to the SRAM allocator for the allocation.
// capabilities        - Queried for the number of SRAMs.
// inputShape          - Shape of the input tensor.
// outputShape         - Shape of the output tensor. Its width, together with the input width, gives the block size.
// sramAllocator       - Allocator in which the work buffers are allocated.
// outIfmStripeShape   - Set to inputShape with the height (index 1) replaced by the last stripe height tried.
// outSpaceToDepthData - On success, receives the number of used EMCs and the two intermediate sizes.
//
// Returns the pair produced by the last call to SramAllocator::Allocate. The first element is true if the
// allocation succeeded; CreateGreedily uses the second element as the SRAM offset of the work buffers.
// If no allocation was attempted (including for the degenerate shapes rejected below), the result is { false, 0 }.
std::pair<bool, uint32_t> SpaceToDepthPass::ChooseAndAllocateSram(const NodeId& nodeId,
                                                                  const HardwareCapabilities& capabilities,
                                                                  const TensorShape& inputShape,
                                                                  const TensorShape& outputShape,
                                                                  SramAllocator& sramAllocator,
                                                                  TensorShape& outIfmStripeShape,
                                                                  SpaceToDepthData& outSpaceToDepthData)
{
    std::pair<bool, uint32_t> intermediateAllocateResult;
    AllocationPreference outputAllocationPreference = AllocationPreference::Start;
    intermediateAllocateResult.first                = false;

    const uint32_t ifmWidth    = GetWidth(inputShape);
    const uint32_t ifmHeight   = GetHeight(inputShape);
    const uint32_t ifmChannels = GetChannels(inputShape);
    const uint32_t ofmWidth    = GetWidth(outputShape);
    const uint32_t numSrams    = capabilities.GetNumberOfSrams();

    // Try taking the whole size first, then move until we find something that works.
    // Assigned up front so that the output stripe shape is well defined on every return path.
    outIfmStripeShape = inputShape;

    // Please refer to the Ethos-N78 DFC specification section 5.3.2 for a full description of the space to depth
    // algorithm. The following is a brief description of the algorithm.
    //
    // The space to depth operator is implemented using a multipass algorithm where the input data is transformed
    // into the correct shape using multiple DMA transfers. Tensor dimensions in this explanation will be written as
    // (height, width, channels)
    //
    // First the input IFM is subdivided so that the subtensors generated by each pass of the algorithm fit in SRAM.
    // The input IFM is subdivided along the Y-axis and the minimal size is (blockSize, ifmWidth, ifmChannels).
    //
    // In the first pass, DramToSram is used to read every blockSize'th row of the subdivided input tensor into a
    // separate subtensor, each s1 bytes per EMC in size. E.g. if blockSize is 2 then there will be 2 subtensors, one
    // containing rows 0, 2, 4, 6.. of the input tensor and the other subtensor containing rows 1, 3, 5...
    // Each subtensor has the dimensions (ifmStripeHeight / blockSize, ifmWidth * ifmChannels / usedEmcs, usedEmcs).
    //
    // In the next pass, the subtensors are converted into column tensors with the width blockSize * ifmChannels using
    // a series of SramToSram transfers, each s2 bytes per EMC in size. This shape allows the DMA to merge the
    // different subtensors correctly when writing the final result into DRAM. The subtensors in this pass have the
    // dimensions (ifmWidth * ifmStripeHeight / blockSize^2, blockSize * ifmChannels / usedEmcs, usedEmcs).
    //
    // The last pass merges the subtensors into DRAM using a series of SramToDram transfers.

    // Degenerate inputs would lead to a division (or modulo) by zero below, so report them as an allocation failure.
    if (ofmWidth == 0 || numSrams == 0 || ifmChannels == 0)
    {
        return intermediateAllocateResult;
    }

    const uint32_t blockSize = ifmWidth / ofmWidth;
    if (blockSize == 0)
    {
        // The input is narrower than the output, which is not a valid space to depth configuration.
        return intermediateAllocateResult;
    }

    // usedEmcs must evenly divide ifmChannels * blockSize.
    // It is at least 1 here, so the search below always terminates (everything is divisible by 1).
    uint32_t usedEmcs = std::min(numSrams, blockSize * ifmChannels);
    while ((blockSize * ifmChannels) % usedEmcs != 0)
    {
        --usedEmcs;
    }

    uint32_t s1 = 0;
    uint32_t s2 = 0;

    // The IFM needs to be subdivided along the Y-axis if all subtensors don't fit in the SRAM at the same time.
    // The smallest possible subdivision is ifmWidth * blockSize * ifmChannels.
    for (uint32_t divisor = 1; divisor <= ifmHeight && !intermediateAllocateResult.first; ++divisor)
    {
        // The stripe height must evenly divide the IFM height and must itself be a multiple of blockSize.
        if ((ifmHeight % divisor) != 0 || ((ifmHeight / divisor) % blockSize) != 0)
        {
            continue;
        }

        outIfmStripeShape[1] = ifmHeight / divisor;

        std::tie(s1, s2) = CalculateSpaceToDepthBlockSizes(outIfmStripeShape, usedEmcs, blockSize);

        intermediateAllocateResult = sramAllocator.Allocate(nodeId, CalculateSpaceToDepthSramUsage(blockSize, s1, s2),
                                                            outputAllocationPreference, "outputs attempt");
    }

    // Store important SRAM block information to avoid duplicate calculations in firmware
    if (intermediateAllocateResult.first)
    {
        outSpaceToDepthData.m_UsedEmcs          = usedEmcs;
        outSpaceToDepthData.m_Intermediate1Size = s1;
        outSpaceToDepthData.m_Intermediate2Size = s2;
    }

    return intermediateAllocateResult;
}

std::unique_ptr<ethosn::support_library::SpaceToDepthPass> SpaceToDepthPass::CreateGreedily(
    const HardwareCapabilities& capabilities, size_t size, Node* firstNode, SramAllocator& sramAllocator)
{
    if (dynamic_cast<SpaceToDepthNode*>(firstNode))
    {
        if (firstNode->GetInputLocation(0) != BufferLocation::Dram)
        {
            // The input must reside in DRAM
            firstNode->GetInput(0)->GetSource()->SetFixGraphLocationHint(LocationHint::RequireDram);
            return std::unique_ptr<SpaceToDepthPass>();
        }

        TensorShape stripeShape;
        SpaceToDepthData spaceToDepthData;

        std::pair<bool, uint32_t> allocationResult =
            ChooseAndAllocateSram(firstNode->GetId(), capabilities, firstNode->GetInputShape(0), firstNode->GetShape(),
                                  sramAllocator, stripeShape, spaceToDepthData);

        if (allocationResult.first)
        {
            // SpaceToDepth's OFM is always in DRAM. Therefore it is safe to free the allocated SRAM.
            sramAllocator.Free(firstNode->GetId(), allocationResult.second);

            return std::make_unique<SpaceToDepthPass>(capabilities, size, firstNode, allocationResult.second,
                                                      stripeShape, spaceToDepthData);
        }
    }

    return std::unique_ptr<SpaceToDepthPass>();
}

// Constructs a pass consisting of the single given node.
//
// The node is added to this pass' node list, is told that it belongs to this pass and
// has its location set to DRAM.
SpaceToDepthPass::SpaceToDepthPass(const HardwareCapabilities& capabilities,
                                   size_t id,
                                   Node* node,
                                   uint32_t workBuffersSramOffset,
                                   const TensorShape& ifmStripeShape,
                                   const SpaceToDepthData& spaceToDepthData)
    : Pass(capabilities, id)
    , m_Node(node)
    , m_WorkBuffersSramOffset(workBuffersSramOffset)
    , m_IfmStripeShape(ifmStripeShape)
    , m_SpaceToDepthData(spaceToDepthData)
{
    // Register the node with the pass.
    m_Nodes.push_back(node);

    // Link the node back to this pass and mark its location as DRAM.
    node->SetPass(this);
    node->SetLocation(BufferLocation::Dram);
}

// Emits the SpaceToDepth command for this pass into the command stream.
//
// cmdStream     - Command stream to which the SpaceToDepth command is appended.
// bufferManager - Used to add the intermediate DRAM buffer which holds the output of this pass.
// dumpRam       - Forwarded to Pass::PostGenerate.
//
// As a side effect, the buffer id of the node is set to the newly added output DRAM buffer.
void SpaceToDepthPass::Generate(command_stream::CommandStreamBuffer& cmdStream,
                                BufferManager& bufferManager,
                                bool dumpRam)
{
    Pass::PreGenerate(cmdStream);

    const TensorShape inputShape  = m_Node->GetInputShape(0);
    const TensorShape outputShape = m_Node->GetShape();

    // The output is written to a new intermediate buffer in DRAM.
    const uint32_t outputSize         = CalculateBufferSize(outputShape, m_Node->GetBufferFormat());
    const uint32_t outputDramBufferId = bufferManager.AddDram(BufferType::Intermediate, outputSize);

    const uint32_t blockSize = GetWidth(inputShape) / GetWidth(outputShape);

    // Each output stripe has its height and width reduced by blockSize,
    // and its channel count multiplied by blockSize^2, relative to the input stripe.
    const TensorShape ofmStripeShape = { m_IfmStripeShape[0], m_IfmStripeShape[1] / blockSize,
                                         m_IfmStripeShape[2] / blockSize,
                                         m_IfmStripeShape[3] * blockSize * blockSize };

    command_stream::SpaceToDepth spaceToDepth;

    // Input tensor description.
    auto&& inputInfo                = spaceToDepth.m_InputInfo();
    inputInfo.m_DataType()          = GetCommandDataType(m_Nodes.front()->GetInputDataType(0));
    inputInfo.m_DataFormat()        = m_Node->GetInputBufferFormat(0);
    inputInfo.m_TensorShape()       = inputShape;
    inputInfo.m_SupertensorShape()  = inputShape;
    inputInfo.m_SupertensorOffset() = { 0, 0, 0, 0 };
    inputInfo.m_DramBufferId()      = m_Node->GetInput(0)->GetSource()->GetBufferId();
    inputInfo.m_ZeroPoint()         = static_cast<int16_t>(m_Node->GetInputQuantizationInfo(0).GetZeroPoint());
    inputInfo.m_DataLocation()      = GetCommandDataLocation(m_Node->GetInputLocation(0));
    inputInfo.m_SramOffset()        = m_WorkBuffersSramOffset;
    inputInfo.m_StripeShape()       = m_IfmStripeShape;
    inputInfo.m_TileSize()          = 0;

    // Output tensor description. It is given the same SRAM offset as the input.
    auto&& outputInfo                = spaceToDepth.m_OutputInfo();
    outputInfo.m_DataType()          = GetCommandDataType(m_Node->GetDataType());
    outputInfo.m_DataFormat()        = m_Node->GetBufferFormat();
    outputInfo.m_TensorShape()       = outputShape;
    outputInfo.m_SupertensorShape()  = outputShape;
    outputInfo.m_SupertensorOffset() = { 0, 0, 0, 0 };
    outputInfo.m_DramBufferId()      = outputDramBufferId;
    outputInfo.m_ZeroPoint()         = static_cast<int16_t>(m_Node->GetQuantizationInfo().GetZeroPoint());
    outputInfo.m_DataLocation()      = GetCommandDataLocation(m_Node->GetLocation());
    outputInfo.m_SramOffset()        = m_WorkBuffersSramOffset;
    outputInfo.m_StripeShape()       = ofmStripeShape;
    outputInfo.m_TileSize()          = 0;

    // SRAM block information that was calculated when the SRAM was allocated for this pass.
    spaceToDepth.m_UsedEmcs()          = m_SpaceToDepthData.m_UsedEmcs;
    spaceToDepth.m_Intermediate1Size() = m_SpaceToDepthData.m_Intermediate1Size;
    spaceToDepth.m_Intermediate2Size() = m_SpaceToDepthData.m_Intermediate2Size;

    m_Node->SetBufferId(outputDramBufferId);

    cmdStream.EmplaceBack(spaceToDepth);

    Pass::PostGenerate(cmdStream, dumpRam);
}

// Returns the performance statistics for this pass.
//
// The non-parallel DRAM statistic of the input is set to the product of the four dimensions of the
// first node's input shape, and that of the output to the product of the four dimensions of the
// last node's shape. The estimation options are not used.
PassStats SpaceToDepthPass::GetStats(const EstimationOptions& estimationOptions)
{
    (void)estimationOptions;

    // Multiplies the four dimensions of a shape together.
    const auto elementCount = [](const TensorShape& shape) -> uint32_t {
        uint32_t total = 1;
        for (uint32_t dim = 0; dim < 4; ++dim)
        {
            total *= shape[dim];
        }
        return total;
    };

    PassStats stats;
    stats.m_Input.m_MemoryStats.m_DramNonParallel  = elementCount(m_Nodes.front()->GetInputShape(0));
    stats.m_Output.m_MemoryStats.m_DramNonParallel = elementCount(m_Nodes.back()->GetShape());

    return stats;
}

// Returns the attributes of the base Pass with the label prefixed by a line naming this pass type.
ethosn::support_library::DotAttributes SpaceToDepthPass::GetDotAttributes()
{
    const char* const passTypeLine = "SpaceToDepthPass\n";

    DotAttributes attributes = Pass::GetDotAttributes();
    attributes.m_Label       = passTypeLine + attributes.m_Label;
    return attributes;
}

}    // namespace support_library
}    // namespace ethosn
